Strassen’s Matrix Multiplication
Brute Force Approach
Brute Force Algorithm
- Initialization: The function void multiply(int A[][N], int B[][N], int C[][N]) takes three N x N matrices A, B, and C. The goal is to compute the product of matrices A and B and store the result in matrix C.
- Nested Loops: The algorithm uses three nested loops:
  - Outer loop (i): iterates over the rows of A.
  - Middle loop (j): iterates over the columns of B.
  - Inner loop (k): computes the dot product of the i-th row of A and the j-th column of B.
- Element Computation: For each element C[i][j]:
  - Initialize C[i][j] to 0.
  - Add the products A[i][k] * B[k][j] for all k from 0 to N-1.
- Result: The final matrix C contains the product of matrices A and B.
- Time Complexity: The time complexity of this brute force approach is $O(N^3)$, involving $N^3$ multiplications and $N^3$ additions.
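For readers who want to run the loops described above, here is a minimal Python sketch (Python is one of the languages used in the code section of this article). It assumes square matrices stored as lists of lists; the function name naive_multiply is illustrative and not part of the original.

```python
def naive_multiply(A, B):
    """Brute force product of two N x N matrices given as lists of lists."""
    n = len(A)
    # C starts as an N x N matrix of zeros
    C = [[0] * n for _ in range(n)]
    for i in range(n):          # outer loop: rows of A
        for j in range(n):      # middle loop: columns of B
            for k in range(n):  # inner loop: dot product of row i of A and column j of B
                C[i][j] += A[i][k] * B[k][j]
    return C


print(naive_multiply([[1, 2], [3, 4]], [[5, 6], [7, 8]]))  # [[19, 22], [43, 50]]
```

The three loops each run N times, so the innermost statement executes $N^3$ times, which is where the $O(N^3)$ bound comes from.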
Divide and Conquer Approach
Input: A and B are square matrices of size n x n, where n is a power of 2 (e.g., 1x1, 2x2, 4x4, 8x8, 16x16, 32x32, etc.).
Output: The resultant matrix C = A * B, where C is a square matrix of size n x n.
Process:
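If n = 2, apply Strassen's Matrix Multiplication Algorithm to compute the elements of matrix C, where A11, A12, A21, A22 and B11, B12, B21, B22 are the four elements of A and B (for larger n, their four n/2 x n/2 quadrants):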
M1 = (A11 + A22) * (B11 + B22)
M2 = (A21 + A22) * B11
M3 = A11 * (B12 - B22)
M4 = A22 * (B21 - B11)
M5 = (A11 + A12) * B22
M6 = (A21 - A11) * (B11 + B12)
M7 = (A12 - A22) * (B21 + B22)
Then, compute the submatrices of C:
C11 = M1 + M4 - M5 + M7
C12 = M3 + M5
C21 = M2 + M4
C22 = M1 - M2 + M3 + M6
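The seven products and four combinations above can be checked directly on a small example. The Python sketch below assumes each 2 x 2 matrix is given as a list of two rows; strassen_2x2 is a hypothetical helper name.

```python
def strassen_2x2(a, b):
    """Multiply two 2 x 2 matrices using Strassen's seven products."""
    (a11, a12), (a21, a22) = a
    (b11, b12), (b21, b22) = b
    m1 = (a11 + a22) * (b11 + b22)
    m2 = (a21 + a22) * b11
    m3 = a11 * (b12 - b22)
    m4 = a22 * (b21 - b11)
    m5 = (a11 + a12) * b22
    m6 = (a21 - a11) * (b11 + b12)
    m7 = (a12 - a22) * (b21 + b22)
    # Combine the products into the four entries of C
    return [[m1 + m4 - m5 + m7, m3 + m5],
            [m2 + m4, m1 - m2 + m3 + m6]]


print(strassen_2x2([[1, 2], [3, 4]], [[5, 6], [7, 8]]))  # [[19, 22], [43, 50]]
```

With these inputs the products are M1 = 65, M2 = 35, M3 = -2, M4 = 8, M5 = 24, M6 = 22, and M7 = -30, so C11 = 65 + 8 - 24 - 30 = 19, C12 = -2 + 24 = 22, C21 = 35 + 8 = 43, and C22 = 65 - 35 - 2 + 22 = 50. This matches the ordinary row-by-column product while using 7 multiplications instead of 8.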
If n > 2, apply the Divide and Conquer method:
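Divide each of A and B into four submatrices of size n/2 x n/2 (eight submatrices in total): A11, A12, A21, A22 and B11, B12, B21, B22.
Recursively compute M1 through M7 on these submatrices with the same formulas as above, then form the submatrices C11, C12, C21, and C22 of C from them.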
Result: Combine the computed submatrices to form the final matrix C.
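Performance: Because each level of recursion needs only seven multiplications of n/2 x n/2 submatrices instead of eight, this divide and conquer approach is asymptotically faster than the standard brute force algorithm for large matrices. For small matrices, the brute force approach is more efficient, because the extra matrix additions and the recursion overhead outweigh the saved multiplication.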
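To make the small-versus-large trade-off concrete, here is a rough operation-count sketch. It assumes a simple cost model (every scalar addition, subtraction, and multiplication costs one unit; memory traffic and recursion overhead are ignored) and that n is a power of 2. The function names are hypothetical.

```python
def brute_force_ops(n):
    """Brute force: n^3 multiplications plus n^2 * (n - 1) additions.

    Each entry's sum starts from its first product, so it needs n - 1 additions.
    """
    return n ** 3 + n * n * (n - 1)


def strassen_ops(n):
    """Strassen recursing all the way down to 1 x 1 blocks.

    Each level does 7 half-size products plus 18 additions/subtractions of
    (n/2) x (n/2) matrices: 10 to form the operands of M1..M7, 8 to form C.
    """
    if n == 1:
        return 1  # a single scalar multiplication
    half = n // 2
    return 7 * strassen_ops(half) + 18 * half * half


def one_level_ops(n):
    """One Strassen step at the top, brute force for the seven half-size products."""
    half = n // 2
    return 7 * brute_force_ops(half) + 18 * half * half


for k in range(1, 11):
    n = 2 ** k
    print(f"{n:5d} brute={brute_force_ops(n):>13,} "
          f"full={strassen_ops(n):>13,} one-level={one_level_ops(n):>13,}")
```

Under this model, full recursion down to 1 x 1 does not beat brute force until n = 1024 (1,971,035,287 versus 2,146,435,072 operations), whereas a single Strassen step on top of brute force already wins at n = 16 (7,872 versus 7,936). This is why practical implementations usually stop the recursion at a cutoff size and switch to the standard algorithm below it.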
Strassen’s Matrix Multiplication Algorithm
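- Base Case (n = 1):
  - If n = 1, the result matrix has a single element: C11 = A11 * B11.
- Case for n = 2:
  - If n = 2, compute M1 through M7 with the formulas above, then compute the elements of the result matrix as:
    C11 = M1 + M4 − M5 + M7
    C12 = M3 + M5
    C21 = M2 + M4
    C22 = M1 − M2 + M3 + M6
- Recursive Case (n > 2):
  - If n > 2, divide A and B into n/2 x n/2 submatrices and compute the seven products recursively, where SMM(X, Y, m) denotes Strassen's multiplication of two m x m matrices:
    M1 = SMM(A11 + A22, B11 + B22, n/2)
    M2 = SMM(A21 + A22, B11, n/2)
    M3 = SMM(A11, B12 − B22, n/2)
    M4 = SMM(A22, B21 − B11, n/2)
    M5 = SMM(A11 + A12, B22, n/2)
    M6 = SMM(A21 − A11, B11 + B12, n/2)
    M7 = SMM(A12 − A22, B21 + B22, n/2)
  - Combine M1 through M7 into C11, C12, C21, and C22 with the same formulas as in the n = 2 case.
- End: The result matrix C is assembled from the four submatrices C11, C12, C21, and C22.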
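As a sketch of SMM, assuming n is a power of 2, the pure-Python function below (lists of lists, no NumPy) follows Strassen's seven-product recursion using the M1 through M7 and C11 through C22 formulas given earlier. The helper names add, sub, split, and smm are illustrative, not from the original. It recurses down to 1 x 1 blocks, so the n = 2 case reduces to exactly those formulas applied to single elements.

```python
def add(X, Y):
    """Element-wise sum of two equally sized matrices."""
    return [[x + y for x, y in zip(rx, ry)] for rx, ry in zip(X, Y)]


def sub(X, Y):
    """Element-wise difference of two equally sized matrices."""
    return [[x - y for x, y in zip(rx, ry)] for rx, ry in zip(X, Y)]


def split(M):
    """Return the four quadrants M11, M12, M21, M22 of a square matrix."""
    h = len(M) // 2
    return ([row[:h] for row in M[:h]], [row[h:] for row in M[:h]],
            [row[:h] for row in M[h:]], [row[h:] for row in M[h:]])


def smm(A, B):
    """Strassen product of two n x n matrices; n must be a power of 2."""
    n = len(A)
    if n == 1:  # base case: C11 = A11 * B11
        return [[A[0][0] * B[0][0]]]
    A11, A12, A21, A22 = split(A)
    B11, B12, B21, B22 = split(B)
    # Seven recursive products on n/2 x n/2 submatrices
    M1 = smm(add(A11, A22), add(B11, B22))
    M2 = smm(add(A21, A22), B11)
    M3 = smm(A11, sub(B12, B22))
    M4 = smm(A22, sub(B21, B11))
    M5 = smm(add(A11, A12), B22)
    M6 = smm(sub(A21, A11), add(B11, B12))
    M7 = smm(sub(A12, A22), add(B21, B22))
    C11 = add(sub(add(M1, M4), M5), M7)
    C12 = add(M3, M5)
    C21 = add(M2, M4)
    C22 = add(sub(add(M1, M3), M2), M6)
    # Stitch the quadrants back together row by row
    top = [r1 + r2 for r1, r2 in zip(C11, C12)]
    bottom = [r1 + r2 for r1, r2 in zip(C21, C22)]
    return top + bottom


print(smm([[1, 2], [3, 4]], [[5, 6], [7, 8]]))  # [[19, 22], [43, 50]]
```

Calling smm on the 4 x 4 matrices from the Python example below (A filled with 1 to 16 row by row, B with 16 down to 1) gives [[80, 70, 60, 50], [240, 214, 188, 162], [400, 358, 316, 274], [560, 502, 444, 386]].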
Strassen's Matrix Multiplication Code
#include <stdio.h>
// Strassen's formulas for a single 2 x 2 product (no recursion)
int main() {
    int a[2][2], b[2][2], c[2][2], i, j;
    int m1, m2, m3, m4, m5, m6, m7;
    // Input first matrix
    printf("Enter the 4 elements of the first matrix: ");
    for (i = 0; i < 2; i++) {
        for (j = 0; j < 2; j++) {
            if (scanf("%d", &a[i][j]) != 1) return 1;  // stop on invalid input
        }
    }
    // Input second matrix
    printf("Enter the 4 elements of the second matrix: ");
    for (i = 0; i < 2; i++) {
        for (j = 0; j < 2; j++) {
            if (scanf("%d", &b[i][j]) != 1) return 1;  // stop on invalid input
        }
    }
    // Display first matrix
    printf("\nThe first matrix is\n");
    for (i = 0; i < 2; i++) {
        for (j = 0; j < 2; j++) {
            printf("%d\t", a[i][j]);
        }
        printf("\n");
    }
    // Display second matrix
    printf("\nThe second matrix is\n");
    for (i = 0; i < 2; i++) {
        for (j = 0; j < 2; j++) {
            printf("%d\t", b[i][j]);
        }
        printf("\n");
    }
    // Strassen's formula calculations
    m1 = (a[0][0] + a[1][1]) * (b[0][0] + b[1][1]);
    m2 = (a[1][0] + a[1][1]) * b[0][0];
    m3 = a[0][0] * (b[0][1] - b[1][1]);
    m4 = a[1][1] * (b[1][0] - b[0][0]);
    m5 = (a[0][0] + a[0][1]) * b[1][1];
    m6 = (a[1][0] - a[0][0]) * (b[0][0] + b[0][1]);
    m7 = (a[0][1] - a[1][1]) * (b[1][0] + b[1][1]);
    // Calculating elements of result matrix c
    c[0][0] = m1 + m4 - m5 + m7;
    c[0][1] = m3 + m5;
    c[1][0] = m2 + m4;
    c[1][1] = m1 - m2 + m3 + m6;
    // Display result matrix
    printf("\nThe resultant matrix after multiplication is\n");
    for (i = 0; i < 2; i++) {
        for (j = 0; j < 2; j++) {
            printf("%d\t", c[i][j]);
        }
        printf("\n");
    }
    return 0;
}
#include <iostream>
#include <vector>
using namespace std;
typedef vector<vector<int>> Matrix;
// Function to add two matrices
Matrix add(const Matrix &A, const Matrix &B) {
    int n = A.size();
    Matrix C(n, vector<int>(n));
    for (int i = 0; i < n; i++) {
        for (int j = 0; j < n; j++) {
            C[i][j] = A[i][j] + B[i][j];
        }
    }
    return C;
}
// Function to subtract two matrices
Matrix subtract(const Matrix &A, const Matrix &B) {
    int n = A.size();
    Matrix C(n, vector<int>(n));
    for (int i = 0; i < n; i++) {
        for (int j = 0; j < n; j++) {
            C[i][j] = A[i][j] - B[i][j];
        }
    }
    return C;
}
// Function to multiply two matrices using Strassen's algorithm
// (the size n must be a power of 2)
Matrix strassen(const Matrix &A, const Matrix &B) {
    int n = A.size();
    if (n == 1) {
        Matrix C(1, vector<int>(1));
        C[0][0] = A[0][0] * B[0][0];
        return C;
    }
    int newSize = n / 2;
    Matrix A11(newSize, vector<int>(newSize));
    Matrix A12(newSize, vector<int>(newSize));
    Matrix A21(newSize, vector<int>(newSize));
    Matrix A22(newSize, vector<int>(newSize));
    Matrix B11(newSize, vector<int>(newSize));
    Matrix B12(newSize, vector<int>(newSize));
    Matrix B21(newSize, vector<int>(newSize));
    Matrix B22(newSize, vector<int>(newSize));
    // Dividing matrices into 4 sub-matrices
    for (int i = 0; i < newSize; i++) {
        for (int j = 0; j < newSize; j++) {
            A11[i][j] = A[i][j];
            A12[i][j] = A[i][j + newSize];
            A21[i][j] = A[i + newSize][j];
            A22[i][j] = A[i + newSize][j + newSize];
            B11[i][j] = B[i][j];
            B12[i][j] = B[i][j + newSize];
            B21[i][j] = B[i + newSize][j];
            B22[i][j] = B[i + newSize][j + newSize];
        }
    }
    // Seven recursive products
    Matrix M1 = strassen(add(A11, A22), add(B11, B22));
    Matrix M2 = strassen(add(A21, A22), B11);
    Matrix M3 = strassen(A11, subtract(B12, B22));
    Matrix M4 = strassen(A22, subtract(B21, B11));
    Matrix M5 = strassen(add(A11, A12), B22);
    Matrix M6 = strassen(subtract(A21, A11), add(B11, B12));
    Matrix M7 = strassen(subtract(A12, A22), add(B21, B22));
    Matrix C11 = add(subtract(add(M1, M4), M5), M7);
    Matrix C12 = add(M3, M5);
    Matrix C21 = add(M2, M4);
    Matrix C22 = add(subtract(add(M1, M3), M2), M6);
    // Combining the 4 sub-matrices into the result
    Matrix C(n, vector<int>(n));
    for (int i = 0; i < newSize; i++) {
        for (int j = 0; j < newSize; j++) {
            C[i][j] = C11[i][j];
            C[i][j + newSize] = C12[i][j];
            C[i + newSize][j] = C21[i][j];
            C[i + newSize][j + newSize] = C22[i][j];
        }
    }
    return C;
}
void printMatrix(const Matrix &M) {
    int n = M.size();
    for (int i = 0; i < n; i++) {
        for (int j = 0; j < n; j++) {
            cout << M[i][j] << " ";
        }
        cout << endl;
    }
}
int main() {
    int n = 4; // Size of the matrix (must be a power of 2)
    Matrix A(n, vector<int>(n, 0));
    Matrix B(n, vector<int>(n, 0));
    // Example values: A holds 1..16 row by row, B holds 16..1
    for (int i = 0; i < n; i++) {
        for (int j = 0; j < n; j++) {
            A[i][j] = i * n + j + 1;
            B[i][j] = n * n - (i * n + j);
        }
    }
    Matrix C = strassen(A, B);
    cout << "Product Matrix:" << endl;
    printMatrix(C);
    return 0;
}
public class StrassenMatrixMultiplication {
    // Function to add two matrices
    public static int[][] add(int[][] A, int[][] B) {
        int n = A.length;
        int[][] C = new int[n][n];
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                C[i][j] = A[i][j] + B[i][j];
            }
        }
        return C;
    }
    // Function to subtract two matrices
    public static int[][] subtract(int[][] A, int[][] B) {
        int n = A.length;
        int[][] C = new int[n][n];
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                C[i][j] = A[i][j] - B[i][j];
            }
        }
        return C;
    }
    // Function to multiply two matrices using Strassen's algorithm
    // (the size n must be a power of 2)
    public static int[][] strassen(int[][] A, int[][] B) {
        int n = A.length;
        if (n == 1) {
            int[][] C = new int[1][1];
            C[0][0] = A[0][0] * B[0][0];
            return C;
        }
        int newSize = n / 2;
        int[][] A11 = new int[newSize][newSize];
        int[][] A12 = new int[newSize][newSize];
        int[][] A21 = new int[newSize][newSize];
        int[][] A22 = new int[newSize][newSize];
        int[][] B11 = new int[newSize][newSize];
        int[][] B12 = new int[newSize][newSize];
        int[][] B21 = new int[newSize][newSize];
        int[][] B22 = new int[newSize][newSize];
        // Dividing matrices into 4 sub-matrices
        for (int i = 0; i < newSize; i++) {
            for (int j = 0; j < newSize; j++) {
                A11[i][j] = A[i][j];
                A12[i][j] = A[i][j + newSize];
                A21[i][j] = A[i + newSize][j];
                A22[i][j] = A[i + newSize][j + newSize];
                B11[i][j] = B[i][j];
                B12[i][j] = B[i][j + newSize];
                B21[i][j] = B[i + newSize][j];
                B22[i][j] = B[i + newSize][j + newSize];
            }
        }
        // Seven recursive products
        int[][] M1 = strassen(add(A11, A22), add(B11, B22));
        int[][] M2 = strassen(add(A21, A22), B11);
        int[][] M3 = strassen(A11, subtract(B12, B22));
        int[][] M4 = strassen(A22, subtract(B21, B11));
        int[][] M5 = strassen(add(A11, A12), B22);
        int[][] M6 = strassen(subtract(A21, A11), add(B11, B12));
        int[][] M7 = strassen(subtract(A12, A22), add(B21, B22));
        int[][] C11 = add(subtract(add(M1, M4), M5), M7);
        int[][] C12 = add(M3, M5);
        int[][] C21 = add(M2, M4);
        int[][] C22 = add(subtract(add(M1, M3), M2), M6);
        // Combining the 4 sub-matrices into the result
        int[][] C = new int[n][n];
        for (int i = 0; i < newSize; i++) {
            for (int j = 0; j < newSize; j++) {
                C[i][j] = C11[i][j];
                C[i][j + newSize] = C12[i][j];
                C[i + newSize][j] = C21[i][j];
                C[i + newSize][j + newSize] = C22[i][j];
            }
        }
        return C;
    }
    // Function to print the matrix
    public static void printMatrix(int[][] matrix) {
        int n = matrix.length;
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                System.out.print(matrix[i][j] + " ");
            }
            System.out.println();
        }
    }
    public static void main(String[] args) {
        int n = 4; // Size of the matrix (must be a power of 2)
        int[][] A = new int[n][n];
        int[][] B = new int[n][n];
        // Example values: A holds 1..16 row by row, B holds 16..1
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                A[i][j] = i * n + j + 1;
                B[i][j] = n * n - (i * n + j);
            }
        }
        int[][] C = strassen(A, B);
        System.out.println("Product Matrix:");
        printMatrix(C);
    }
}
import numpy as np
def add(A, B):
    return np.add(A, B)
def subtract(A, B):
    return np.subtract(A, B)
def strassen(A, B):
    # A and B are n x n NumPy arrays; n must be a power of 2
    n = A.shape[0]
    if n == 1:
        return A * B
    new_size = n // 2
    # Dividing matrices into 4 sub-matrices (slices are views, not copies)
    A11 = A[:new_size, :new_size]
    A12 = A[:new_size, new_size:]
    A21 = A[new_size:, :new_size]
    A22 = A[new_size:, new_size:]
    B11 = B[:new_size, :new_size]
    B12 = B[:new_size, new_size:]
    B21 = B[new_size:, :new_size]
    B22 = B[new_size:, new_size:]
    # Seven recursive products
    M1 = strassen(add(A11, A22), add(B11, B22))
    M2 = strassen(add(A21, A22), B11)
    M3 = strassen(A11, subtract(B12, B22))
    M4 = strassen(A22, subtract(B21, B11))
    M5 = strassen(add(A11, A12), B22)
    M6 = strassen(subtract(A21, A11), add(B11, B12))
    M7 = strassen(subtract(A12, A22), add(B21, B22))
    C11 = add(subtract(add(M1, M4), M5), M7)
    C12 = add(M3, M5)
    C21 = add(M2, M4)
    C22 = add(subtract(add(M1, M3), M2), M6)
    # Combining the 4 sub-matrices into the result
    C = np.zeros((n, n), dtype=A.dtype)
    C[:new_size, :new_size] = C11
    C[:new_size, new_size:] = C12
    C[new_size:, :new_size] = C21
    C[new_size:, new_size:] = C22
    return C
# Example usage
A = np.array([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]])
B = np.array([[16, 15, 14, 13], [12, 11, 10, 9], [8, 7, 6, 5], [4, 3, 2, 1]])
C = strassen(A, B)
print("Product Matrix:")
print(C)
Analysis of Algorithm
Time Complexity of Divide and Conquer Approach
For n × n matrix multiplication using Strassen’s algorithm, the time complexity is derived from the divide and conquer recurrence.
Key Points:
- Base Case: If n = 1, the algorithm performs 1 multiplication, so $T(1) = 1$.
- Recursive Case: For $n = 2^k$, where $k = \log_2 n$:
  - The algorithm makes 7 recursive calls, each operating on a subproblem of size n/2.
  - The recurrence relation is $T(n) = 7T(n/2) + cn^2$, where $cn^2$ is the time for the matrix additions.
Deriving the Time Complexity:
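- Counting only the scalar multiplications (that is, dropping the $cn^2$ addition term), start with the recurrence relation:
  $T(n) = 7T(n/2)$
- Expand it by substituting recursively:
  $T(n) = 7\,(7T(n/4)) = 7^2\,T(n/4)$
  $T(n) = 7^k\,T(n/2^k)$
- Since $n = 2^k$ and $T(1) = 1$, the final expansion is:
  $T(n) = 7^{\log_2 n} \cdot T(1) = n^{\log_2 7}$
  where the last step uses $7^{\log_2 n} = 2^{(\log_2 7)(\log_2 n)} = n^{\log_2 7}$.
Here, $\log_2 7 \approx 2.807$.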
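A quick numerical check of this derivation: the short Python sketch below counts multiplications through the recurrence $T(n) = 7T(n/2)$, $T(1) = 1$ (ignoring the addition term, as in the derivation) and compares the count with $7^k$, $n^{\log_2 7}$, and the brute force $n^3$. The function name multiplications is illustrative.

```python
import math


def multiplications(n):
    """Scalar multiplications when Strassen recurses down to 1 x 1 (T(1) = 1)."""
    return 1 if n == 1 else 7 * multiplications(n // 2)


for k in range(11):
    n = 2 ** k
    # 7^k and n^(log2 7) agree; round() absorbs floating-point error in the power
    print(n, multiplications(n), 7 ** k, round(n ** math.log2(7)), n ** 3)
```

For n = 1024 this gives 7^10 = 282,475,249 multiplications, compared with 1024^3 = 1,073,741,824 for the brute force method.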
Conclusion:
The time complexity of Strassen’s algorithm is:
$T(n) = O(n^{\log_2 7}) \approx O(n^{2.807})$
This is an improvement over the $O(n^3)$ time complexity of standard matrix multiplication.
Time Complexity Analysis (Expanded)
- Recurrence Relation:
  $T(n) = 7T(n/2) + cn^2$
  where $cn^2$ accounts for the matrix addition operations.
- Master Theorem Application:
  Using the Master Theorem, the time complexity is $O(n^{\log_2 7}) \approx O(n^{2.807})$.
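The Master Theorem step can be spelled out as a standard case-1 check for the recurrence above, taking $a = 7$ recursive calls, subproblem size $n/b$ with $b = 2$, and $f(n) = cn^2$ for the additions:

```latex
\begin{aligned}
& T(n) = a\,T(n/b) + f(n), \quad a = 7,\ b = 2,\ f(n) = cn^2 \\
& n^{\log_b a} = n^{\log_2 7} \approx n^{2.807} \\
& f(n) = cn^2 = O\big(n^{\log_2 7 - \varepsilon}\big) \text{ with } \varepsilon = \log_2 7 - 2 \approx 0.807 > 0 \\
& \Rightarrow\ T(n) = \Theta\big(n^{\log_2 7}\big) \quad \text{(Master Theorem, case 1)}
\end{aligned}
```

Because the $n^2$ cost of the additions grows polynomially slower than $n^{\log_2 7}$, the seven recursive multiplications dominate the running time.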
Summary:
Strassen’s algorithm reduces the time complexity of matrix multiplication from $O(n^3)$ to $O(n^{\log_2 7}) \approx O(n^{2.807})$, making it asymptotically more efficient for large matrices.